Skip to content

Add typing friendly transformation interfaces#90

Open
wdhongtw wants to merge 1 commit intomainfrom
transform-typing
Open

Add typing friendly transformation interfaces#90
wdhongtw wants to merge 1 commit intomainfrom
transform-typing

Conversation

@wdhongtw
Copy link
Copy Markdown
Collaborator

For torch_view and jax_view, expose a <func>_elem variant that easier for user code to leverage typing analysis.

torch_view and jax_view are super powerful, but also quite hard to annotate it's function signature, as they support arbitrary tree structure. Users need to manually typing.cast in order to this powerful interface, while still have the type being inferenced.

In this PR, we try to expose another interface, that's less powerful, but friendlier for typing system.
As long as all the input given with their types, the computation results can all have correct types.
Typing analysis can works properly in this case. Also, IDE can provide more help for developer, e.g. method auto-complete, jump-to-definition ... etc.

import torch
import torchax.interop

def add(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_j = torchax.interop.jax_view(a)
    b_j = torchax.interop.jax_view(b)
    c_j = a_j + b_j
    c = torchax.interop.torch_view(c_j)
    # Although the logic is correct, a_j, b_j, c_j and c all being Any here.
    return c


def sub(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    a_j = torchax.interop.jax_view_elem(a)
    b_j = torchax.interop.jax_view_elem(b)
    c_j = a_j - b_j
    c = torchax.interop.torch_view_elem(c_j)
    # They can correctly being inferred as jax.Array and torch.Tensor
    return c

Signed-off-by: Weida Hong <wdhongtw@google.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants